/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A scoring model that computes scores based on the summation of multiple weighted trees. Example
 * models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
 *
 * <p>Example configuration:
 *
 * <pre>
 * {
 *   "class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
 *   "name" : "multipleadditivetreesmodel",
 *   "features":[
 *     { "name" : "userTextTitleMatch"},
 *     { "name" : "originalScore"}
 *   ],
 *   "params" : {
 *     "trees" : [
 *       {
 *         "weight" : "1",
 *         "root": {
 *           "feature" : "userTextTitleMatch",
 *           "threshold" : "0.5",
 *           "left" : {
 *             "value" : "-100"
 *           },
 *           "right" : {
 *             "feature" : "originalScore",
 *             "threshold" : "10.0",
 *             "left" : {
 *               "value" : "50"
 *             },
 *             "right" : {
 *               "value" : "75"
 *             }
 *           }
 *         }
 *       },
 *       {
 *         "weight" : "2",
 *         "root" : {
 *           "value" : "-10"
 *         }
 *       }
 *     ]
 *   }
 * }
 * </pre>
 *
 * <p>Training libraries:
 *
 * <ul>
 *   <li><a href="http://sourceforge.net/p/lemur/wiki/RankLib/">RankLib</a>
 * </ul>
 *
 * <p>Background reading:
 *
 * <ul>
 *   <li><a href="http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf">Christopher J.C.
 *       Burges. From RankNet to LambdaRank to LambdaMART: An Overview. Microsoft Research Technical
 *       Report MSR-TR-2010-82.</a>
 * </ul>
 *
 * <ul>
 *   <li><a
 *       href="https://papers.nips.cc/paper/3305-a-general-boosting-method-and-its-application-to-learning-ranking-functions-for-web-search.pdf">
 *       Z. Zheng, H. Zha, T. Zhang, O. Chapelle, K. Chen, and G. Sun. A General Boosting Method and
 *       its Application to Learning Ranking Functions for Web Search. Advances in Neural
 *       Information Processing Systems (NIPS), 2007.</a>
 * </ul>
 */
public class MultipleAdditiveTreesModel extends LTRScoringModel {

  /**
   * fname2index is filled from constructor arguments (that are already part of the base class
   * hashCode) and therefore here it does not individually influence the class hashCode, equals,
   * etc.
   */
  private final HashMap<String, Integer> fname2index;

  /**
   * trees is part of the LTRScoringModel params map and therefore here it does not individually
   * influence the class hashCode, equals, etc.
   */
  private List<RegressionTree> trees;

  // Controls NaN handling: when true (the default), NaN feature values are normalized and scored
  // like any other float; when false, normalization skips NaN entries and tree traversal routes
  // them via each node's configured "missing" direction.
  private boolean isNullSameAsZero = true;

  /**
   * Builds a {@link RegressionTree} and populates it from the given attribute map via setter
   * injection; a null map yields a tree with no attributes set.
   */
  private RegressionTree createRegressionTree(Map<String, Object> map) {
    final RegressionTree tree = new RegressionTree();
    if (map == null) {
      return tree;
    }
    SolrPluginUtils.invokeSetters(tree, map.entrySet());
    return tree;
  }

  /**
   * Builds a {@link RegressionTreeNode} and populates it from the given attribute map via setter
   * injection; a null map yields a node with no attributes set (i.e. a zero-valued leaf).
   */
  private RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
    final RegressionTreeNode node = new RegressionTreeNode();
    if (map == null) {
      return node;
    }
    SolrPluginUtils.invokeSetters(node, map.entrySet());
    return node;
  }

  /**
   * Model parameter setter controlling how NaN (missing) feature values are handled during
   * normalization and tree scoring; see {@code isNullSameAsZero}.
   */
  public void setIsNullSameAsZero(boolean nullSameAsZero) {
    isNullSameAsZero = nullSameAsZero;
  }

  /**
   * One node of a regression tree: either a leaf carrying a value, or a split on a feature with a
   * threshold, left/right children, and an optional direction for missing (NaN) values.
   */
  public class RegressionTreeNode {
    // Added to every configured threshold (and subtracted again in toString), giving the split
    // comparison a small tolerance above the raw threshold.
    private static final float NODE_SPLIT_SLACK = 1E-6f;

    private float value = 0f;
    private String feature;
    private int featureIndex = -1;
    private Float threshold;
    private RegressionTreeNode left;
    private RegressionTreeNode right;
    private String missing;

    public RegressionTreeNode() {}

    /** A node with no feature configured is a leaf. */
    public boolean isLeaf() {
      return feature == null;
    }

    public void setValue(float value) {
      this.value = value;
    }

    public void setValue(String value) {
      this.value = Float.parseFloat(value);
    }

    /** Direction ("left" or anything else, meaning right) taken when the feature value is NaN. */
    public void setMissing(String direction) {
      this.missing = direction;
    }

    public void setFeature(String feature) {
      this.feature = feature;
      final Integer idx = fname2index.get(feature);
      // A null index means the tree references a feature that does not exist; this could be due
      // to lambdaSmart building off of pre-existing trees that use a feature that is no longer
      // output during feature extraction. The sentinel -1 makes scoring return zero for the node.
      featureIndex = (idx == null) ? -1 : idx;
    }

    public void setThreshold(float threshold) {
      this.threshold = threshold + NODE_SPLIT_SLACK;
    }

    public void setThreshold(String threshold) {
      this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
    }

    @SuppressWarnings({"unchecked"})
    public void setLeft(Object left) {
      this.left = createRegressionTreeNode((Map<String, Object>) left);
    }

    @SuppressWarnings({"unchecked"})
    public void setRight(Object right) {
      this.right = createRegressionTreeNode((Map<String, Object>) right);
    }

    @Override
    public String toString() {
      if (isLeaf()) {
        return Float.toString(value);
      }
      final StringBuilder sb = new StringBuilder();
      sb.append("(feature=").append(feature);
      sb.append(",threshold=").append(threshold.floatValue() - NODE_SPLIT_SLACK);
      if (missing != null) {
        sb.append(",missing=").append(missing);
      }
      return sb.append(",left=")
          .append(left)
          .append(",right=")
          .append(right)
          .append(')')
          .toString();
    }
  }

  /** A weighted regression tree; its contribution to the model score is weight * root traversal. */
  public class RegressionTree {

    private Float weight;
    private RegressionTreeNode root;

    public RegressionTree() {}

    public void setWeight(float weight) {
      this.weight = weight;
    }

    public void setWeight(String weight) {
      this.weight = Float.valueOf(weight);
    }

    @SuppressWarnings({"unchecked"})
    public void setRoot(Object root) {
      this.root = createRegressionTreeNode((Map<String, Object>) root);
    }

    /** Evaluates this tree against the (normalized) feature vector, applying the tree weight. */
    public float score(float[] featureVector) {
      final float raw =
          isNullSameAsZero
              ? scoreNode(featureVector, root)
              : scoreNodeWithNullSupport(featureVector, root);
      return weight.floatValue() * raw;
    }

    /** Returns a human-readable trace of the path taken through this tree for the given vector. */
    public String explain(float[] featureVector) {
      return explainNode(featureVector, root);
    }

    @Override
    public String toString() {
      return "(weight=" + weight + ",root=" + root + ")";
    }

    /** Checks that a weight and a well-formed root node were both configured. */
    public void validate() throws ModelException {
      if (weight == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
      }
      if (root == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
      }
      validateNode(root);
    }
  }

  /**
   * Model parameter setter: each list element is a map of tree attributes ("weight", "root") from
   * which a {@link RegressionTree} is built.
   */
  @SuppressWarnings({"unchecked"})
  public void setTrees(Object trees) {
    this.trees = new ArrayList<>();
    for (final Object treeSpec : (List<Object>) trees) {
      this.trees.add(createRegressionTree((Map<String, Object>) treeSpec));
    }
  }

  /**
   * Constructs the model and caches each feature name's position in the model feature vector, so
   * that tree nodes can resolve their feature references at load time.
   */
  public MultipleAdditiveTreesModel(
      String name,
      List<Feature> features,
      List<Normalizer> norms,
      String featureStoreName,
      List<Feature> allFeatures,
      Map<String, Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);

    fname2index = new HashMap<>();
    int position = 0;
    for (final Feature feature : features) {
      fname2index.put(feature.getName(), position++);
    }
  }

  /**
   * Validates base-class state and then the ensemble: the tree list must be declared, and every
   * tree must individually be well-formed.
   */
  @Override
  protected void validate() throws ModelException {
    super.validate();
    if (trees == null) {
      throw new ModelException("no trees declared for model " + name);
    }
    for (final RegressionTree tree : trees) {
      tree.validate();
    }
  }

  /** Normalizes the feature vector in place, honouring the configured NaN handling flag. */
  @Override
  public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
    normalizeFeaturesInPlace(modelFeatureValues, isNullSameAsZero);
  }

  /**
   * Applies each feature's normalizer to the corresponding vector entry, in place.
   *
   * @param modelFeatureValues feature vector, one entry per model feature
   * @param isNullSameAsZero when false, NaN entries (missing values) are left untouched
   * @throws FeatureException if the vector length does not match the number of normalizers
   */
  protected void normalizeFeaturesInPlace(float[] modelFeatureValues, boolean isNullSameAsZero) {
    if (modelFeatureValues.length != norms.size()) {
      throw new FeatureException("Must have normalizer for every feature");
    }
    for (int idx = 0; idx < modelFeatureValues.length; ++idx) {
      final float value = modelFeatureValues[idx];
      // With null-as-zero semantics every entry is normalized; otherwise NaN marks "missing"
      // and must pass through unchanged so the trees can route it explicitly.
      if (isNullSameAsZero || !Float.isNaN(value)) {
        modelFeatureValues[idx] = norms.get(idx).normalize(value);
      }
    }
  }

  /** The model score is the sum of every tree's weighted score. */
  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    float total = 0f;
    for (final RegressionTree tree : trees) {
      total += tree.score(modelFeatureValuesNormalized);
    }
    return total;
  }

  /**
   * Iteratively walks the tree from the given node down to a leaf and returns the leaf value.
   * NaN feature values are treated like any other float here (NaN fails the {@code <=} test and
   * therefore always descends right). A node referencing an unknown feature scores 0.
   */
  private static float scoreNode(float[] featureVector, RegressionTreeNode regressionTreeNode) {
    RegressionTreeNode current = regressionTreeNode;
    while (!current.isLeaf()) {
      final int idx = current.featureIndex;
      // unsupported feature (tree is looking for a feature that does not exist)
      if (idx < 0 || idx >= featureVector.length) {
        return 0f;
      }
      current = (featureVector[idx] <= current.threshold) ? current.left : current.right;
    }
    return current.value;
  }

  /**
   * Iteratively walks the tree from the given node down to a leaf, with explicit missing-value
   * support: a NaN feature value follows the node's configured "missing" direction ("left" goes
   * left; anything else — including no configured direction — goes right, matching the behaviour
   * of {@code explainNode}). A node referencing an unknown feature scores 0.
   */
  private static float scoreNodeWithNullSupport(
      float[] featureVector, RegressionTreeNode regressionTreeNode) {
    while (true) {
      if (regressionTreeNode.isLeaf()) {
        return regressionTreeNode.value;
      }
      final int idx = regressionTreeNode.featureIndex;
      // unsupported feature (tree is looking for a feature that does not exist)
      if (idx < 0 || idx >= featureVector.length) {
        return 0f;
      }

      final float value = featureVector[idx];
      if (Float.isNaN(value)) {
        // Fix: the previous switch on regressionTreeNode.missing threw a NullPointerException
        // when no "missing" direction was configured for the node; default to the right child,
        // exactly as explainNode does (Objects.equals is null-safe).
        regressionTreeNode =
            Objects.equals(regressionTreeNode.missing, "left")
                ? regressionTreeNode.left
                : regressionTreeNode.right;
      } else if (value <= regressionTreeNode.threshold) {
        regressionTreeNode = regressionTreeNode.left;
      } else {
        regressionTreeNode = regressionTreeNode.right;
      }
    }
  }

  /**
   * Structurally validates the subtree rooted at the given node using an explicit stack (avoiding
   * deep recursion): every leaf must have no children, and every split node must carry a threshold
   * plus both a left and a right child.
   */
  private static void validateNode(RegressionTreeNode regressionTreeNode) throws ModelException {
    final Deque<RegressionTreeNode> pending = new ArrayDeque<>();
    pending.push(regressionTreeNode);

    while (!pending.isEmpty()) {
      final RegressionTreeNode node = pending.pop();

      if (node.isLeaf()) {
        if (node.left != null || node.right != null) {
          throw new ModelException(
              "MultipleAdditiveTreesModel tree node is leaf with left="
                  + node.left
                  + " and right="
                  + node.right);
        }
        continue;
      }
      if (node.threshold == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      }
      if (node.left == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      }
      pending.push(node.left);
      if (node.right == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
      }
      pending.push(node.right);
    }
  }

  /**
   * Produces a textual trace of the path taken through the tree for the given feature vector,
   * mirroring the routing performed by {@code scoreNode} / {@code scoreNodeWithNullSupport}.
   */
  private static String explainNode(float[] featureVector, RegressionTreeNode node) {
    final StringBuilder path = new StringBuilder();
    while (!node.isLeaf()) {
      final int idx = node.featureIndex;

      // unsupported feature (tree is looking for a feature that does not exist)
      if (idx < 0 || idx >= featureVector.length) {
        return path.append("'")
            .append(node.feature)
            .append("' does not exist in FV, Return Zero")
            .toString();
      }

      // could store extra information about how much training data supported
      // each branch and report that here
      final float value = featureVector[idx];
      if (value <= node.threshold) {
        path.append("'")
            .append(node.feature)
            .append("':")
            .append(value)
            .append(" <= ")
            .append(node.threshold)
            .append(", Go Left | ");
        node = node.left;
      } else if (value > node.threshold) {
        path.append("'")
            .append(node.feature)
            .append("':")
            .append(value)
            .append(" > ")
            .append(node.threshold)
            .append(", Go Right | ");
        node = node.right;
      } else if (Float.isNaN(value)) {
        // NaN fails both comparisons above; follow the configured missing direction
        // (null-safe, defaulting to right).
        final boolean goLeft = Objects.equals(node.missing, "left");
        path.append("'")
            .append(node.feature)
            .append(goLeft ? "': NaN, Go Left | " : "': NaN, Go Right | ");
        node = goLeft ? node.left : node.right;
      }
    }
    return path.append("val: ").append(node.value).toString();
  }

  /**
   * Produces an explanation that looks like:
   *
   * <pre>
   * 40.0 = multipleadditivetreesmodel [ org.apache.solr.ltr.model.MultipleAdditiveTreesModel ]
   * model applied to features, sum of:
   * 50.0 = tree 0 | 'matchedTitle':1.0 &gt; 0.500001, Go Right |
   * 'this_feature_doesnt_exist' does not exist in FV, Go Left | val: 50.0
   * -10.0 = tree 1 | val: -10.0
   * </pre>
   */
  @Override
  public Explanation explain(
      LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
    // Rebuild the (normalized) feature vector from the per-feature explanations.
    final float[] fv = new float[featureExplanations.size()];
    for (int i = 0; i < fv.length; i++) {
      fv[i] = featureExplanations.get(i).getValue().floatValue();
    }

    // One sub-explanation per tree: its weighted score plus the traversal path taken.
    final List<Explanation> details = new ArrayList<>();
    for (int treeIndex = 0; treeIndex < trees.size(); treeIndex++) {
      final RegressionTree tree = trees.get(treeIndex);
      final float treeScore = tree.score(fv);
      details.add(Explanation.match(treeScore, "tree " + treeIndex + " | " + tree.explain(fv)));
    }

    return Explanation.match(
        finalScore, toString() + " model applied to features, sum of:", details);
  }

  /** Renders the model name followed by the full tree ensemble. */
  @Override
  public String toString() {
    final StringBuilder sb =
        new StringBuilder(getClass().getSimpleName()).append("(name=").append(getName());
    sb.append(",trees=[");
    String separator = "";
    for (final RegressionTree tree : trees) {
      sb.append(separator).append(tree);
      separator = ",";
    }
    return sb.append("])").toString();
  }
}
